from pyspark.sql.session import SparkSession
from pyspark.sql.functions import *
from pyspark.ml.classification import LogisticRegression

# Train a logistic-regression classifier on a libsvm data set, report its
# accuracy on a held-out split, and persist the fitted model.
spark = SparkSession.builder.getOrCreate()

# 1. Load the data (libsvm format yields "label" and "features" columns).
image_data = spark.read.format("libsvm").load("../../data/image_data")

# Uncomment to inspect the loaded rows:
# image_data.show(truncate=False)

# 2. Split the data into a training set (70%) and a test set (30%).
#    No seed is passed, so the split (and the reported accuracy) varies
#    from run to run.
train, test = image_data.randomSplit([0.7, 0.3])

# 3. Choose the algorithm (default hyper-parameters).
lr = LogisticRegression()

# 4. Fit the algorithm on the training set to obtain the model.
model = lr.fit(train)

# 5. Run the model on the test set. The predictions are cached because the
#    accuracy calculation below triggers two separate count() actions; without
#    caching each action would re-read the data and re-run the model.
test_predict = model.transform(test).cache()

# 6. Compute the accuracy: correctly predicted rows / all test rows.
total_rows = test_predict.count()
correct_rows = test_predict.where(col("prediction") == col("label")).count()
test_predict.unpersist()

# A random split of a very small data set can leave the test set empty;
# report NaN in that case instead of crashing with ZeroDivisionError after
# the model has already been trained.
acc = correct_rows / total_rows if total_rows else float("nan")

print(f"准确率：{acc}")

# 7. Save the model.
#    NOTE(review): per the PySpark docs, save() raises if the target path
#    already exists; use model.write().overwrite().save(...) instead if
#    re-running this script is meant to replace the stored model.
model.save("../../data/image_model")
